import os
import sys

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import torch.optim as optim
from torch import nn as nn

from trainer.coteaching import CoTeachingTrainer
from utils import get_configs, get_datasets, get_logger, get_model, seed_everything


def get_optimizer(config, model):
    """Create the optimizer named by ``config["train"]["optimizer_type"]``.

    Args:
        config: Nested mapping; reads ``config["train"]["optimizer_type"]``
            and ``config["train"]["learning_rate"]``.
        model: Module whose ``parameters()`` are handed to the optimizer.

    Returns:
        ``optim.SGD`` (momentum 0.9, weight decay 5e-4), ``optim.Adam`` or
        ``optim.RMSprop`` for ``"sgd"``, ``"adam"`` and ``"rmsprop"``
        respectively. Only the learning rate comes from the config.

    Raises:
        ValueError: If the optimizer type is not one of the names above.
    """
    train_cfg = config["train"]
    kind = train_cfg["optimizer_type"]
    lr = train_cfg["learning_rate"]

    # Each builder is only invoked for the matching type, so
    # model.parameters() is requested exactly once on success.
    builders = {
        "sgd": lambda params: optim.SGD(
            params, lr=lr, momentum=0.9, weight_decay=5e-4
        ),
        "adam": lambda params: optim.Adam(params, lr=lr),
        "rmsprop": lambda params: optim.RMSprop(params, lr=lr),
    }
    for name, build in builders.items():
        if kind == name:
            return build(model.parameters())

    raise ValueError(f"Unsupported optimizer type: {kind}")


def get_scheduler(config, optimizer):
    """Create the LR scheduler named by ``config["train"]["scheduler_type"]``.

    Args:
        config: Nested mapping. Always reads ``config["train"]["scheduler_type"]``;
            for ``"step"`` also ``scheduler_gamma`` and ``scheduler_step_size``,
            for ``"cosine"`` also ``scheduler_T_max`` (all under ``"train"``).
        optimizer: Optimizer the scheduler will drive.

    Returns:
        ``optim.lr_scheduler.StepLR`` or ``optim.lr_scheduler.CosineAnnealingLR``.

    Raises:
        ValueError: If the scheduler type is neither ``"step"`` nor ``"cosine"``.
    """
    train_cfg = config["train"]
    kind = train_cfg["scheduler_type"]

    if kind == "step":
        # Type-specific keys are only looked up for the matching branch.
        gamma = train_cfg["scheduler_gamma"]
        step_size = train_cfg["scheduler_step_size"]
        return optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

    if kind == "cosine":
        t_max = train_cfg["scheduler_T_max"]
        return optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max)

    raise ValueError(f"Unsupported scheduler type: {kind}")


def main(config):
    """Build two peer models with their own optimizers and run co-teaching.

    Args:
        config: Full run configuration, forwarded to the dataset, model,
            optimizer and logger factories and to ``CoTeachingTrainer``.
    """
    print("==> Preparing data..")
    # valset is unpacked but not used by this entry point.
    trainset, testset, valset, num_classes = get_datasets(config)

    # Two separate model instances built from the same config, each placed
    # on the first CUDA device and paired with its own optimizer.
    peers = [get_model(config, num_classes) for _ in range(2)]
    for peer in peers:
        peer.to("cuda:0")
    peer_optimizers = [get_optimizer(config, peer) for peer in peers]
    logger = get_logger(config)

    first_model, second_model = peers
    first_opt, second_opt = peer_optimizers
    trainer = CoTeachingTrainer(
        config,
        first_model,
        second_model,
        logger,
        trainset,
        testset,
        nn.CrossEntropyLoss(),
        first_opt,
        # NOTE(review): get_scheduler is never called here, so whatever this
        # slot means to CoTeachingTrainer receives None — confirm intended.
        None,
        second_opt,
    )
    trainer.run()


if __name__ == "__main__":
    run_config = get_configs()
    # Seed via the project helper before main() builds data and models.
    seed_everything(run_config["general"]["np_seed"])
    main(run_config)
